Add top-p and top-k sampling to GenerationUtils #2137
ac3439b to 37627f7
2b904f2 to 8331b24
8331b24 to 1361a14
High level, I notice you implemented the sampling as part of the
Thoughts? @yohann-benchetrit
torchtext/prototype/generate.py
Outdated
    log_probs = F.log_softmax(decoder_output[:, -1], dim=-1)

    if do_sample:
        probs = log_probs.softmax(dim=-1)
Taking double softmax here? Probs are already softmax'd on L91
Thanks for spotting this! It should indeed be .exp() to recover the original probabilities.
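To make the point in this thread concrete, here is a minimal stdlib-only sketch with toy logits. The helper names are hypothetical and this is not the torchtext code. It shows that exponentiating log-probabilities is the direct way back to probabilities. Because softmax is invariant to shifting its input by a constant, applying softmax to log-probabilities gives the same numbers, so the second softmax is redundant rather than a change in the distribution.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats (illustrative helper)."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def softmax(values):
    """Numerically stable softmax over a list of floats (illustrative helper)."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 0.5, -1.0]  # toy values, not taken from the PR
log_probs = log_softmax(logits)

# Direct inverse of the log: the plain-Python analogue of log_probs.exp()
probs_via_exp = [math.exp(lp) for lp in log_probs]

# Second softmax on log-probabilities: same result, extra normalization work
probs_via_softmax = softmax(log_probs)
```

Both lists agree with `softmax(logits)` up to floating-point error, which is why replacing the second softmax with an exponential is a simplification and not a behavioral change.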
Will want to add temperature, but can do that in a follow-up PR. Tracking issue: #2138
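For readers unfamiliar with temperature, the following is a rough sketch of the usual technique: divide the logits by a positive temperature before the softmax. The function name and the toy logits are assumptions for illustration, and this is not necessarily how the follow-up implements it.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Softmax over logits scaled by 1/temperature (hypothetical helper).

    temperature < 1 sharpens the distribution, temperature > 1 flattens it.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 0.5, -1.0]  # toy values
sharp = softmax_with_temperature(logits, temperature=0.5)
neutral = softmax_with_temperature(logits, temperature=1.0)
flat = softmax_with_temperature(logits, temperature=2.0)
```

With these toy logits, the most likely token gets more probability mass in `sharp` than in `neutral`, and less in `flat`.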
I agree with @joecummings here in that by allowing users to provide An alternative proposal here could be to keep the
@Nayef211 We abstract away the internal methods, so high-level users should never have to call
Agreed, thanks @joecummings and @Nayef211 for your comments! So with this additional information my understanding is:
My thoughts on this:
I like this idea more and unless there are customers that are asking for sampling to be implemented in
1e98f39 to d1fd6ff
Thanks again for your comments. I addressed them as follows:
56434b7 to 8e1e907
Address PR comments and add Temperature
8e1e907 to 1fc72a2
joecummings left a comment
Looks good. Thanks @yohann-benchetrit!
GenerationUtils.generate: top-p, top-k, remove_invalid_values, temperature (see Issue 2138)
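As an illustration of the two sampling options named above, here is a minimal stdlib-only sketch of top-k and top-p (nucleus) filtering over a toy probability vector. All function names are hypothetical, the code is not taken from this PR, and it does not cover remove_invalid_values.

```python
import random

def top_k_filter(probs, k):
    """Keep the k most probable tokens, zero out the rest, and renormalize."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept = set(order[:k])
    filtered = [p if i in kept else 0.0 for i, p in enumerate(probs)]
    total = sum(filtered)
    return [p / total for p in filtered]

def top_p_filter(probs, p):
    """Keep the smallest set of most probable tokens whose cumulative
    probability reaches p, zero out the rest, and renormalize."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = set(), 0.0
    for i in order:
        kept.add(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    filtered = [q if i in kept else 0.0 for i, q in enumerate(probs)]
    total = sum(filtered)
    return [q / total for q in filtered]

def sample(probs, rng):
    """Draw one token index according to the given probabilities."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

probs = [0.5, 0.3, 0.15, 0.05]      # toy distribution over four tokens
top2 = top_k_filter(probs, k=2)     # only tokens 0 and 1 remain
nucleus = top_p_filter(probs, p=0.9)  # tokens 0, 1, 2 remain (mass 0.95)
token = sample(top2, random.Random(0))
```

Both filters only reshape the distribution. The actual draw is a separate step, so temperature scaling, filtering, and sampling can be combined in sequence.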